{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DjUA6S30k52h"
      },
      "source": [
        "##### Copyright 2021 The TensorFlow Authors."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "SpNWyqewk8fE"
      },
      "outputs": [],
      "source": [
        "#@title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "6x1ypzczQCwy"
      },
      "source": [
        "# Feature Engineering using TFX Pipeline and TensorFlow Transform\n",
        "\n",
        "***Transform input data and train a model with a TFX pipeline.***"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HU9YYythm0dx"
      },
      "source": [
        "Note: We recommend running this tutorial in a Colab notebook, with no setup required!  Just click \"Run in Google Colab\".\n",
        "\n",
        "<div class=\"devsite-table-wrapper\"><table class=\"tfo-notebook-buttons\" align=\"left\">\n",
        "<td><a target=\"_blank\" href=\"https://www.tensorflow.org/tfx/tutorials/tfx/penguin_tft\">\n",
        "<img src=\"https://www.tensorflow.org/images/tf_logo_32px.png\"/>View on TensorFlow.org</a></td>\n",
        "<td><a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tfx/blob/master/docs/tutorials/tfx/penguin_tft.ipynb\">\n",
        "<img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\">Run in Google Colab</a></td>\n",
        "<td><a target=\"_blank\" href=\"https://github.com/tensorflow/tfx/tree/master/docs/tutorials/tfx/penguin_tft.ipynb\">\n",
        "<img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\">View source on GitHub</a></td>\n",
        "<td><a href=\"https://storage.googleapis.com/tensorflow_docs/tfx/docs/tutorials/tfx/penguin_tft.ipynb\"><img src=\"https://www.tensorflow.org/images/download_logo_32px.png\" />Download notebook</a></td>\n",
        "</table></div>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_VuwrlnvQJ5k"
      },
      "source": [
        "In this notebook-based tutorial, we will create and run a TFX pipeline\n",
        "to ingest raw input data and preprocess it appropriately for ML training.\n",
        "This notebook is based on the TFX pipeline we built in\n",
        "[Data validation using TFX Pipeline and TensorFlow Data Validation Tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/penguin_tfdv).\n",
        "If you have not read that one yet, you should read it before proceeding with\n",
        "this notebook.\n",
        "\n",
        "Feature engineering can increase the predictive quality of your data and/or\n",
        "reduce its dimensionality. One of the benefits of using TFX is that you write\n",
        "your transformation code once, and the resulting transforms are applied\n",
        "consistently in training and serving, which avoids training/serving skew.\n",
        "\n",
        "We will add a `Transform` component to the pipeline. The Transform component is\n",
        "implemented using the\n",
        "[tf.transform](https://www.tensorflow.org/tfx/transform/get_started) library.\n",
        "\n",
        "Please see\n",
        "[Understanding TFX Pipelines](https://www.tensorflow.org/tfx/guide/understanding_tfx_pipelines)\n",
        "to learn more about various concepts in TFX."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Fmgi8ZvQkScg"
      },
      "source": [
        "## Set Up\n",
        "We first need to install the TFX Python package and download\n",
        "the dataset which we will use for our model.\n",
        "\n",
        "### Upgrade Pip\n",
        "\n",
        "To avoid upgrading Pip on a local system, we first check\n",
        "that we are running in Colab.\n",
        "Local systems can of course be upgraded separately."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "as4OTe2ukSqm"
      },
      "outputs": [],
      "source": [
        "try:\n",
        "  import colab\n",
        "  !pip install --upgrade pip\n",
        "except ImportError:\n",
        "  pass"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "MZOYTt1RW4TK"
      },
      "source": [
        "### Install TFX\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "iyQtljP-qPHY"
      },
      "outputs": [],
      "source": [
        "!pip install -U tfx"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EwT0nov5QO1M"
      },
      "source": [
        "### Did you restart the runtime?\n",
        "\n",
        "If you are using Google Colab, the first time that you run\n",
        "the cell above, you must restart the runtime by clicking\n",
        "the \"RESTART RUNTIME\" button above or by using the \"Runtime > Restart\n",
        "runtime ...\" menu. This is because of the way that Colab\n",
        "loads packages."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BDnPgN8UJtzN"
      },
      "source": [
        "Check the TensorFlow and TFX versions."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6jh7vKSRqPHb"
      },
      "outputs": [],
      "source": [
        "import tensorflow as tf\n",
        "print('TensorFlow version: {}'.format(tf.__version__))\n",
        "from tfx import v1 as tfx\n",
        "print('TFX version: {}'.format(tfx.__version__))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "aDtLdSkvqPHe"
      },
      "source": [
        "### Set up variables\n",
        "\n",
        "There are some variables used to define a pipeline. You can customize these\n",
        "variables as you want. By default all output from the pipeline will be\n",
        "generated under the current directory."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EcUseqJaE2XN"
      },
      "outputs": [],
      "source": [
        "import os\n",
        "\n",
        "PIPELINE_NAME = \"penguin-transform\"\n",
        "\n",
        "# Output directory to store artifacts generated from the pipeline.\n",
        "PIPELINE_ROOT = os.path.join('pipelines', PIPELINE_NAME)\n",
        "# Path to a SQLite DB file to use as an MLMD storage.\n",
        "METADATA_PATH = os.path.join('metadata', PIPELINE_NAME, 'metadata.db')\n",
        "# Output directory where created models from the pipeline will be exported.\n",
        "SERVING_MODEL_DIR = os.path.join('serving_model', PIPELINE_NAME)\n",
        "\n",
        "from absl import logging\n",
        "logging.set_verbosity(logging.INFO)  # Set default logging level."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "qsO0l5F3dzOr"
      },
      "source": [
        "### Prepare example data\n",
        "We will download the example dataset for use in our TFX pipeline. The dataset\n",
        "we are using is the\n",
        "[Palmer Penguins dataset](https://allisonhorst.github.io/palmerpenguins/articles/intro.html).\n",
        "\n",
        "However, unlike previous tutorials which used an already preprocessed dataset,\n",
        "we will use the **raw** Palmer Penguins dataset.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "11J7XiCq6AFP"
      },
      "source": [
        "Because the TFX ExampleGen component reads inputs from a directory, we need\n",
        "to create a directory and copy the dataset to it."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4fxMs6u86acP"
      },
      "outputs": [],
      "source": [
        "import urllib.request\n",
        "import tempfile\n",
        "\n",
        "DATA_ROOT = tempfile.mkdtemp(prefix='tfx-data')  # Create a temporary directory.\n",
        "_data_path = 'https://storage.googleapis.com/download.tensorflow.org/data/palmer_penguins/penguins_size.csv'\n",
        "_data_filepath = os.path.join(DATA_ROOT, \"data.csv\")\n",
        "urllib.request.urlretrieve(_data_path, _data_filepath)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ASpoNmxKSQjI"
      },
      "source": [
        "Take a quick look at what the raw data looks like."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-eSz28UDSnlG"
      },
      "outputs": [],
      "source": [
        "!head {_data_filepath}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "OTtQNq1DdVvG"
      },
      "source": [
        "There are some entries with missing values which are represented as `NA`.\n",
        "We will just delete those entries in this tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fQhpoaqff9ca"
      },
      "outputs": [],
      "source": [
        "!sed -i '/\\bNA\\b/d' {_data_filepath}\n",
        "!head {_data_filepath}"
      ]
    },
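    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The `sed` command above may not be available everywhere (for example, on\n",
        "Windows). As a rough sketch, the same filtering can be written in Python.\n",
        "The rows in `sample_lines` are illustrative stand-ins that are assumed to\n",
        "follow the raw CSV's column layout; they are not read from the downloaded\n",
        "file. The names `sample_lines`, `na_pattern`, and `kept_lines` are made up\n",
        "for this example. To apply the idea to the real file, you would read\n",
        "`_data_filepath`, filter its lines the same way, and write the result back."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import re\n",
        "\n",
        "# Illustrative rows only (assumed column layout, not read from the file).\n",
        "sample_lines = [\n",
        "    'species,island,culmen_length_mm,culmen_depth_mm,flipper_length_mm,body_mass_g,sex',\n",
        "    'Adelie,Torgersen,39.1,18.7,181,3750,MALE',\n",
        "    'Adelie,Torgersen,NA,NA,NA,NA,NA',\n",
        "    'Adelie,Torgersen,36.7,19.3,193,3450,FEMALE',\n",
        "]\n",
        "\n",
        "# Same pattern as the sed command: 'NA' as a whole word.\n",
        "na_pattern = re.compile(r'\\bNA\\b')\n",
        "\n",
        "# Keep only the lines that contain no 'NA' entry.\n",
        "kept_lines = [line for line in sample_lines if not na_pattern.search(line)]\n",
        "\n",
        "for line in kept_lines:\n",
        "  print(line)"
      ]
    },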
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z8EOfCy1dzO2"
      },
      "source": [
        "You should be able to see seven features which describe penguins. We will use\n",
        "the same set of features as the previous tutorials - 'culmen_length_mm',\n",
        "'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g' - and will predict\n",
        "the 'species' of a penguin.\n",
        "\n",
        "**The only difference will be that the input data is not preprocessed.** Note\n",
        "that we will not use other features like 'island' or 'sex' in this tutorial."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Jtbrkjjc-IKA"
      },
      "source": [
        "### Prepare a schema file\n",
        "\n",
        "As described in\n",
        "[Data validation using TFX Pipeline and TensorFlow Data Validation Tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/penguin_tfdv),\n",
        "we need a schema file for the dataset. Because the dataset is different from\n",
        "the one in the previous tutorial, the schema has to be generated again. In this\n",
        "tutorial, we will skip those steps and just use a prepared schema file.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EDoB97m8B9nG"
      },
      "outputs": [],
      "source": [
        "import shutil\n",
        "\n",
        "SCHEMA_PATH = 'schema'\n",
        "\n",
        "_schema_uri = 'https://raw.githubusercontent.com/tensorflow/tfx/master/tfx/examples/penguin/schema/raw/schema.pbtxt'\n",
        "_schema_filename = 'schema.pbtxt'\n",
        "_schema_filepath = os.path.join(SCHEMA_PATH, _schema_filename)\n",
        "\n",
        "os.makedirs(SCHEMA_PATH, exist_ok=True)\n",
        "urllib.request.urlretrieve(_schema_uri, _schema_filepath)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "gKJ_HDJQB94b"
      },
      "source": [
        "This schema file was created with the same pipeline as in the previous tutorial\n",
        "without any manual changes."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "nH6gizcpSwWV"
      },
      "source": [
        "## Create a pipeline\n",
        "\n",
        "TFX pipelines are defined using Python APIs. We will add a `Transform`\n",
        "component to the pipeline we created in the\n",
        "[Data Validation tutorial](https://www.tensorflow.org/tfx/tutorials/tfx/penguin_tfdv).\n",
        "\n",
        "A Transform component requires input data from an `ExampleGen` component and\n",
        "a schema from a `SchemaGen` component (or, as in this tutorial, a schema\n",
        "imported from a file), and produces a \"transform graph\". The\n",
        "output will be used in a `Trainer` component. Transform can optionally\n",
        "produce \"transformed data\" in addition, which is the materialized data after\n",
        "transformation.\n",
        "However, we will transform data during training in this tutorial without\n",
        "materialization of the intermediate transformed data.\n",
        "\n",
        "One thing to note is that we need to define a Python function,\n",
        "`preprocessing_fn`, to describe how input data should be transformed. This is\n",
        "similar to a Trainer component, which also requires user code for model\n",
        "definition.\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lOjDv93eS5xV"
      },
      "source": [
        "### Write preprocessing and training code\n",
        "\n",
        "We need to define two Python functions: one for Transform and one for Trainer.\n",
        "\n",
        "#### preprocessing_fn\n",
        "The Transform component will find a function named `preprocessing_fn` in the\n",
        "given module file, just as the `Trainer` component does for `run_fn`. You can\n",
        "also point to a particular function using the\n",
        "[`preprocessing_fn` parameter](https://github.com/tensorflow/tfx/blob/142de6e887f26f4101ded7925f60d7d4fe9d42ed/tfx/components/transform/component.py#L113)\n",
        "of the Transform component.\n",
        "\n",
        "In this example, we will do two kinds of transformations. For continuous numeric\n",
        "features like `culmen_length_mm` and `body_mass_g`, we will normalize these\n",
        "values using the\n",
        "[`tft.scale_to_z_score`](https://www.tensorflow.org/tfx/transform/api_docs/python/tft/scale_to_z_score)\n",
        "function. For the label feature, we need to convert string labels into numeric\n",
        "index values. We will use\n",
        "[`tf.lookup.StaticHashTable`](https://www.tensorflow.org/api_docs/python/tf/lookup/StaticHashTable)\n",
        "for conversion.\n",
        "\n",
        "To identify transformed fields easily, we append a `_xf` suffix to the\n",
        "transformed feature names.\n",
        "\n",
        "#### run_fn\n",
        "\n",
        "The model itself is almost the same as in the previous tutorials, but this time\n",
        "we will transform the input data using the transform graph from the Transform\n",
        "component.\n",
        "\n",
        "One more important difference compared to the previous tutorial is that now we\n",
        "export a model for serving which includes not only the computation graph of the\n",
        "model, but also the transform graph for preprocessing, which is generated by\n",
        "the Transform component. We need to define a separate function which will be\n",
        "used for serving incoming requests. You can see that the same function,\n",
        "`_apply_preprocessing`, is used for both the training data and the\n",
        "serving requests.\n"
      ]
    },
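    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the two transformations described above more concrete, here is a\n",
        "conceptual sketch in plain NumPy and Python. It does not call `tft` or\n",
        "`tf.lookup`; it only mimics the arithmetic on a handful of hypothetical\n",
        "values. It assumes a population standard deviation (NumPy's default), and the\n",
        "names `body_mass_g`, `z_scores`, `label_to_index`, and `label_indices` are\n",
        "made up for illustration. In the real pipeline, Transform computes the\n",
        "statistics over the whole dataset and stores them in the transform graph, so\n",
        "that the same constants are reused when serving."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical body_mass_g values, for illustration only.\n",
        "body_mass_g = np.array([3750.0, 3800.0, 3250.0, 5400.0, 4650.0])\n",
        "\n",
        "# Z-score: subtract the mean, then divide by the standard deviation.\n",
        "# Assumption: population standard deviation (ddof=0, NumPy's default).\n",
        "z_scores = (body_mass_g - body_mass_g.mean()) / body_mass_g.std()\n",
        "print(z_scores)\n",
        "\n",
        "# Label lookup: string -> index, with -1 for labels that are not in the\n",
        "# table, mirroring default_value=-1 of the hash table in preprocessing_fn.\n",
        "table_keys = ['Adelie', 'Chinstrap', 'Gentoo']\n",
        "label_to_index = {name: index for index, name in enumerate(table_keys)}\n",
        "\n",
        "# 'Emperor' is deliberately not in the table.\n",
        "labels = ['Gentoo', 'Adelie', 'Emperor']\n",
        "label_indices = [label_to_index.get(name, -1) for name in labels]\n",
        "print(label_indices)"
      ]
    },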
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "aES7Hv5QTDK3"
      },
      "outputs": [],
      "source": [
        "_module_file = 'penguin_utils.py'"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Gnc67uQNTDfW"
      },
      "outputs": [],
      "source": [
        "%%writefile {_module_file}\n",
        "\n",
        "\n",
        "from typing import List, Text\n",
        "from absl import logging\n",
        "import tensorflow as tf\n",
        "from tensorflow import keras\n",
        "from tensorflow_metadata.proto.v0 import schema_pb2\n",
        "import tensorflow_transform as tft\n",
        "from tensorflow_transform.tf_metadata import schema_utils\n",
        "\n",
        "from tfx import v1 as tfx\n",
        "from tfx_bsl.public import tfxio\n",
        "\n",
        "# Specify features that we will use.\n",
        "_FEATURE_KEYS = [\n",
        "    'culmen_length_mm', 'culmen_depth_mm', 'flipper_length_mm', 'body_mass_g'\n",
        "]\n",
        "_LABEL_KEY = 'species'\n",
        "\n",
        "_TRAIN_BATCH_SIZE = 20\n",
        "_EVAL_BATCH_SIZE = 10\n",
        "\n",
        "\n",
        "# NEW: Transformed features will have '_xf' suffix.\n",
        "def _transformed_name(key):\n",
        "  return key + '_xf'\n",
        "\n",
        "\n",
        "# NEW: TFX Transform will call this function.\n",
        "def preprocessing_fn(inputs):\n",
        "  \"\"\"tf.transform's callback function for preprocessing inputs.\n",
        "\n",
        "  Args:\n",
        "    inputs: map from feature keys to raw not-yet-transformed features.\n",
        "\n",
        "  Returns:\n",
        "    Map from string feature key to transformed feature.\n",
        "  \"\"\"\n",
        "  outputs = {}\n",
        "\n",
        "  # Uses features defined in _FEATURE_KEYS only.\n",
        "  for key in _FEATURE_KEYS:\n",
        "    # tft.scale_to_z_score computes the mean and variance of the given feature\n",
        "    # and scales the output based on the result.\n",
        "    outputs[_transformed_name(key)] = tft.scale_to_z_score(inputs[key])\n",
        "\n",
        "  # For the label column we provide the mapping from string to index.\n",
        "  # We could instead use `tft.compute_and_apply_vocabulary()` in order to\n",
        "  # compute the vocabulary dynamically and perform a lookup.\n",
        "  # Since in this example there are only 3 possible values, we use a hard-coded\n",
        "  # table for simplicity.\n",
        "  table_keys = ['Adelie', 'Chinstrap', 'Gentoo']\n",
        "  initializer = tf.lookup.KeyValueTensorInitializer(\n",
        "      keys=table_keys,\n",
        "      values=tf.cast(tf.range(len(table_keys)), tf.int64),\n",
        "      key_dtype=tf.string,\n",
        "      value_dtype=tf.int64)\n",
        "  table = tf.lookup.StaticHashTable(initializer, default_value=-1)\n",
        "  outputs[_transformed_name(_LABEL_KEY)] = table.lookup(inputs[_LABEL_KEY])\n",
        "\n",
        "  return outputs\n",
        "\n",
        "\n",
        "# NEW: This function will apply the same transform operation to training data\n",
        "#      and serving requests.\n",
        "def _apply_preprocessing(raw_features, tft_layer):\n",
        "  transformed_features = tft_layer(raw_features)\n",
        "  if _LABEL_KEY in raw_features:\n",
        "    transformed_label = transformed_features.pop(_transformed_name(_LABEL_KEY))\n",
        "    return transformed_features, transformed_label\n",
        "  else:\n",
        "    return transformed_features, None\n",
        "\n",
        "\n",
        "# NEW: This function will create a handler function which gets a serialized\n",
        "#      tf.example, preprocess and run an inference with it.\n",
        "def _get_serve_tf_examples_fn(model, tf_transform_output):\n",
        "  # We must save the tft_layer to the model to ensure its assets are kept and\n",
        "  # tracked.\n",
        "  model.tft_layer = tf_transform_output.transform_features_layer()\n",
        "\n",
        "  @tf.function(input_signature=[\n",
        "      tf.TensorSpec(shape=[None], dtype=tf.string, name='examples')\n",
        "  ])\n",
        "  def serve_tf_examples_fn(serialized_tf_examples):\n",
        "    # Expected input is a string which is serialized tf.Example format.\n",
        "    feature_spec = tf_transform_output.raw_feature_spec()\n",
        "    # Because input schema includes unnecessary fields like 'species' and\n",
        "    # 'island', we filter feature_spec to include required keys only.\n",
        "    required_feature_spec = {\n",
        "        k: v for k, v in feature_spec.items() if k in _FEATURE_KEYS\n",
        "    }\n",
        "    parsed_features = tf.io.parse_example(serialized_tf_examples,\n",
        "                                          required_feature_spec)\n",
        "\n",
        "    # Preprocess parsed input with transform operation defined in\n",
        "    # preprocessing_fn().\n",
        "    transformed_features, _ = _apply_preprocessing(parsed_features,\n",
        "                                                   model.tft_layer)\n",
        "    # Run inference with ML model.\n",
        "    return model(transformed_features)\n",
        "\n",
        "  return serve_tf_examples_fn\n",
        "\n",
        "\n",
        "def _input_fn(file_pattern: List[Text],\n",
        "              data_accessor: tfx.components.DataAccessor,\n",
        "              tf_transform_output: tft.TFTransformOutput,\n",
        "              batch_size: int = 200) -> tf.data.Dataset:\n",
        "  \"\"\"Generates features and label for tuning/training.\n",
        "\n",
        "  Args:\n",
        "    file_pattern: List of paths or patterns of input tfrecord files.\n",
        "    data_accessor: DataAccessor for converting input to RecordBatch.\n",
        "    tf_transform_output: A TFTransformOutput.\n",
        "    batch_size: representing the number of consecutive elements of returned\n",
        "      dataset to combine in a single batch\n",
        "\n",
        "  Returns:\n",
        "    A dataset that contains (features, indices) tuple where features is a\n",
        "      dictionary of Tensors, and indices is a single Tensor of label indices.\n",
        "  \"\"\"\n",
        "  dataset = data_accessor.tf_dataset_factory(\n",
        "      file_pattern,\n",
        "      tfxio.TensorFlowDatasetOptions(batch_size=batch_size),\n",
        "      schema=tf_transform_output.raw_metadata.schema)\n",
        "\n",
        "  transform_layer = tf_transform_output.transform_features_layer()\n",
        "  def apply_transform(raw_features):\n",
        "    return _apply_preprocessing(raw_features, transform_layer)\n",
        "\n",
        "  return dataset.map(apply_transform).repeat()\n",
        "\n",
        "\n",
        "def _build_keras_model() -> tf.keras.Model:\n",
        "  \"\"\"Creates a DNN Keras model for classifying penguin data.\n",
        "\n",
        "  Returns:\n",
        "    A Keras Model.\n",
        "  \"\"\"\n",
        "  # The model below is built with Functional API, please refer to\n",
        "  # https://www.tensorflow.org/guide/keras/overview for all API options.\n",
        "  inputs = [\n",
        "      keras.layers.Input(shape=(1,), name=_transformed_name(f))\n",
        "      for f in _FEATURE_KEYS\n",
        "  ]\n",
        "  d = keras.layers.concatenate(inputs)\n",
        "  for _ in range(2):\n",
        "    d = keras.layers.Dense(8, activation='relu')(d)\n",
        "  outputs = keras.layers.Dense(3)(d)\n",
        "\n",
        "  model = keras.Model(inputs=inputs, outputs=outputs)\n",
        "  model.compile(\n",
        "      optimizer=keras.optimizers.Adam(1e-2),\n",
        "      loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),\n",
        "      metrics=[keras.metrics.SparseCategoricalAccuracy()])\n",
        "\n",
        "  model.summary(print_fn=logging.info)\n",
        "  return model\n",
        "\n",
        "\n",
        "# TFX Trainer will call this function.\n",
        "def run_fn(fn_args: tfx.components.FnArgs):\n",
        "  \"\"\"Train the model based on given args.\n",
        "\n",
        "  Args:\n",
        "    fn_args: Holds args used to train the model as name/value pairs.\n",
        "  \"\"\"\n",
        "  tf_transform_output = tft.TFTransformOutput(fn_args.transform_output)\n",
        "\n",
        "  train_dataset = _input_fn(\n",
        "      fn_args.train_files,\n",
        "      fn_args.data_accessor,\n",
        "      tf_transform_output,\n",
        "      batch_size=_TRAIN_BATCH_SIZE)\n",
        "  eval_dataset = _input_fn(\n",
        "      fn_args.eval_files,\n",
        "      fn_args.data_accessor,\n",
        "      tf_transform_output,\n",
        "      batch_size=_EVAL_BATCH_SIZE)\n",
        "\n",
        "  model = _build_keras_model()\n",
        "  model.fit(\n",
        "      train_dataset,\n",
        "      steps_per_epoch=fn_args.train_steps,\n",
        "      validation_data=eval_dataset,\n",
        "      validation_steps=fn_args.eval_steps)\n",
        "\n",
        "  # NEW: Save a computation graph including transform layer.\n",
        "  signatures = {\n",
        "      'serving_default': _get_serve_tf_examples_fn(model, tf_transform_output),\n",
        "  }\n",
        "  model.save(fn_args.serving_model_dir, save_format='tf', signatures=signatures)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "blaw0rs-emEf"
      },
      "source": [
        "Now you have completed all of the preparation steps to build a TFX pipeline."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "w3OkNz3gTLwM"
      },
      "source": [
        "### Write a pipeline definition\n",
        "\n",
        "We define a function to create a TFX pipeline. A `Pipeline` object\n",
        "represents a TFX pipeline, which can be run using one of the pipeline\n",
        "orchestration systems that TFX supports.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "M49yYVNBTPd4"
      },
      "outputs": [],
      "source": [
        "def _create_pipeline(pipeline_name: str, pipeline_root: str, data_root: str,\n",
        "                     schema_path: str, module_file: str, serving_model_dir: str,\n",
        "                     metadata_path: str) -> tfx.dsl.Pipeline:\n",
        "  \"\"\"Implements the penguin pipeline with TFX.\"\"\"\n",
        "  # Brings data into the pipeline or otherwise joins/converts training data.\n",
        "  example_gen = tfx.components.CsvExampleGen(input_base=data_root)\n",
        "\n",
        "  # Computes statistics over data for visualization and example validation.\n",
        "  statistics_gen = tfx.components.StatisticsGen(\n",
        "      examples=example_gen.outputs['examples'])\n",
        "\n",
        "  # Import the schema.\n",
        "  schema_importer = tfx.dsl.Importer(\n",
        "      source_uri=schema_path,\n",
        "      artifact_type=tfx.types.standard_artifacts.Schema).with_id(\n",
        "          'schema_importer')\n",
        "\n",
        "  # Performs anomaly detection based on statistics and data schema.\n",
        "  example_validator = tfx.components.ExampleValidator(\n",
        "      statistics=statistics_gen.outputs['statistics'],\n",
        "      schema=schema_importer.outputs['result'])\n",
        "\n",
        "  # NEW: Transforms input data using preprocessing_fn in the 'module_file'.\n",
        "  transform = tfx.components.Transform(\n",
        "      examples=example_gen.outputs['examples'],\n",
        "      schema=schema_importer.outputs['result'],\n",
        "      materialize=False,\n",
        "      module_file=module_file)\n",
        "\n",
        "  # Uses user-provided Python function that trains a model.\n",
        "  trainer = tfx.components.Trainer(\n",
        "      module_file=module_file,\n",
        "      examples=example_gen.outputs['examples'],\n",
        "\n",
        "      # NEW: Pass transform_graph to the trainer.\n",
        "      transform_graph=transform.outputs['transform_graph'],\n",
        "\n",
        "      train_args=tfx.proto.TrainArgs(num_steps=100),\n",
        "      eval_args=tfx.proto.EvalArgs(num_steps=5))\n",
        "\n",
        "  # Pushes the model to a filesystem destination.\n",
        "  pusher = tfx.components.Pusher(\n",
        "      model=trainer.outputs['model'],\n",
        "      push_destination=tfx.proto.PushDestination(\n",
        "          filesystem=tfx.proto.PushDestination.Filesystem(\n",
        "              base_directory=serving_model_dir)))\n",
        "\n",
        "  components = [\n",
        "      example_gen,\n",
        "      statistics_gen,\n",
        "      schema_importer,\n",
        "      example_validator,\n",
        "\n",
        "      transform,  # NEW: Transform component was added to the pipeline.\n",
        "\n",
        "      trainer,\n",
        "      pusher,\n",
        "  ]\n",
        "\n",
        "  return tfx.dsl.Pipeline(\n",
        "      pipeline_name=pipeline_name,\n",
        "      pipeline_root=pipeline_root,\n",
        "      metadata_connection_config=tfx.orchestration.metadata\n",
        "      .sqlite_metadata_connection_config(metadata_path),\n",
        "      components=components)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mJbq07THU2GV"
      },
      "source": [
        "## Run the pipeline\n",
        "\n",
        "We will use `LocalDagRunner` as in the previous tutorial."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "fAtfOZTYWJu-"
      },
      "outputs": [],
      "source": [
        "tfx.orchestration.LocalDagRunner().run(\n",
        "  _create_pipeline(\n",
        "      pipeline_name=PIPELINE_NAME,\n",
        "      pipeline_root=PIPELINE_ROOT,\n",
        "      data_root=DATA_ROOT,\n",
        "      schema_path=SCHEMA_PATH,\n",
        "      module_file=_module_file,\n",
        "      serving_model_dir=SERVING_MODEL_DIR,\n",
        "      metadata_path=METADATA_PATH))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ppERq0Mj6xvW"
      },
      "source": [
        "You should see \"INFO:absl:Component Pusher is finished.\" if the pipeline\n",
        "finished successfully.\n",
        "\n",
        "The pusher component pushes the trained model to the `SERVING_MODEL_DIR`, which\n",
        "is the `serving_model/penguin-transform` directory if you did not change\n",
        "the variables in the previous steps. You can see the result in the file\n",
        "browser in the left-side panel in Colab, or by using the following command:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "NTHROkqX6yHx"
      },
      "outputs": [],
      "source": [
        "# List files in created model directory.\n",
        "!find {SERVING_MODEL_DIR}"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "VTqM-WiZkPbt"
      },
      "source": [
        "You can also check the signature of the generated model using the\n",
        "[`saved_model_cli` tool](https://www.tensorflow.org/guide/saved_model#show_command)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "YBfUzD_OkOq_"
      },
      "outputs": [],
      "source": [
        "!saved_model_cli show --dir {SERVING_MODEL_DIR}/$(ls -1 {SERVING_MODEL_DIR} | sort -nr | head -1) --tag_set serve --signature_def serving_default"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "DkAxFs_QszoZ"
      },
      "source": [
        "Because we defined `serving_default` with our own `serve_tf_examples_fn`\n",
        "function, the signature shows that it takes a single string tensor.\n",
        "Each string is a serialized tf.Example and will be parsed with the\n",
        "[tf.io.parse_example()](https://www.tensorflow.org/api_docs/python/tf/io/parse_example)\n",
        "function as we defined earlier (learn more about tf.Examples [here](https://www.tensorflow.org/tutorials/load_data/tfrecord)).\n",
        "\n",
        "We can load the exported model and try some inferences with a few examples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Z1Yw5yYdvqKf"
      },
      "outputs": [],
      "source": [
        "# Find a model with the latest timestamp.\n",
        "model_dirs = (item for item in os.scandir(SERVING_MODEL_DIR) if item.is_dir())\n",
        "model_path = max(model_dirs, key=lambda i: int(i.name)).path\n",
        "\n",
        "loaded_model = tf.keras.models.load_model(model_path)\n",
        "inference_fn = loaded_model.signatures['serving_default']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xrOHIvnIv0-4"
      },
      "outputs": [],
      "source": [
        "# Prepare an example and run inference.\n",
        "features = {\n",
        "  'culmen_length_mm': tf.train.Feature(float_list=tf.train.FloatList(value=[49.9])),\n",
        "  'culmen_depth_mm': tf.train.Feature(float_list=tf.train.FloatList(value=[16.1])),\n",
        "  'flipper_length_mm': tf.train.Feature(int64_list=tf.train.Int64List(value=[213])),\n",
        "  'body_mass_g': tf.train.Feature(int64_list=tf.train.Int64List(value=[5400])),\n",
        "}\n",
        "example_proto = tf.train.Example(features=tf.train.Features(feature=features))\n",
        "examples = example_proto.SerializeToString()\n",
        "\n",
        "result = inference_fn(examples=tf.constant([examples]))\n",
        "print(result['output_0'].numpy())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "cri3mTgZ0SQ2"
      },
      "source": [
        "The third element, which corresponds to the 'Gentoo' species, is expected to be\n",
        "the largest of the three."
      ]
    },
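    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The model's last `Dense(3)` layer has no activation and the loss is configured\n",
        "with `from_logits=True`, so the values printed above are logits rather than\n",
        "probabilities. As a minimal sketch, the cell below turns logits into\n",
        "probabilities and a species name. The `logits` array is a hypothetical\n",
        "stand-in, not real model output, and the `species` list assumes the same\n",
        "order as `table_keys` in `preprocessing_fn`."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "# Hypothetical logits shaped like the model output: one example, three classes.\n",
        "logits = np.array([[-2.1, 0.3, 4.2]])\n",
        "\n",
        "# Numerically stable softmax: shift by the row maximum before exponentiating.\n",
        "shifted = logits - logits.max(axis=-1, keepdims=True)\n",
        "exponentials = np.exp(shifted)\n",
        "probabilities = exponentials / exponentials.sum(axis=-1, keepdims=True)\n",
        "\n",
        "# Assumed to match the order of table_keys in preprocessing_fn.\n",
        "species = ['Adelie', 'Chinstrap', 'Gentoo']\n",
        "predicted = [species[index] for index in probabilities.argmax(axis=-1)]\n",
        "\n",
        "print(probabilities)\n",
        "print(predicted)"
      ]
    },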
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "08R8qvweThRf"
      },
      "source": [
        "## Next steps\n",
        "\n",
        "If you want to learn more about the Transform component, see the\n",
        "[Transform Component guide](https://www.tensorflow.org/tfx/guide/transform).\n",
        "You can find more resources on https://www.tensorflow.org/tfx/tutorials.\n",
        "\n",
        "Please see\n",
        "[Understanding TFX Pipelines](https://www.tensorflow.org/tfx/guide/understanding_tfx_pipelines)\n",
        "to learn more about various concepts in TFX.\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "collapsed_sections": [
        "DjUA6S30k52h"
      ],
      "name": "penguin_tft.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
